// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <stdint.h>
#include "dataflow_api.h"
#include <vector>

#include "ttnn/operations/transformer/sdpa_decode/device/kernels/rt_args_common.hpp"
#include "dataflow_common.hpp"
// #include "debug/dprint.h"

// Reader kernel for SDPA decode. Streams Q (sharded or interleaved), K, V,
// an optional attention mask, an optional attention sink, and an optional
// page table into circular buffers for the compute kernel. Handles both
// paged and contiguous KV caches, causal and non-causal attention, and
// dynamic K-chunk sizing based on the current position.
void kernel_main() {
    /*
    In DRAM, Q is (B, PNHt, DHt), K is (B, St, DHt), V is (B, St, DHt), mask is (B, PNHt, PSt)
    We want to read for a particular batch cur_batch, and sequence length up to padded layer length.
    We read Q: (cur_batch, PNHt, DHt), K: (cur_batch, PSt, DHt), V: (cur_batch, PSt, DHt), mask: (cur_batch, PNHt, PSt)
    */
    constexpr uint32_t B = get_compile_time_arg_val(0);           // batch size
    constexpr uint32_t PNHt = get_compile_time_arg_val(1);        // padded number of heads in tiles
    constexpr uint32_t St = get_compile_time_arg_val(2);          // full sequence length of kv cache in tiles
    constexpr uint32_t DHt = get_compile_time_arg_val(3);         // head dim
    constexpr uint32_t vDHt = get_compile_time_arg_val(4);        // head dim of V
    constexpr uint32_t Sk_chunk_t = get_compile_time_arg_val(5);  // number of tiles in seqlen of a k/v/mask chunk
    constexpr uint32_t num_cores = get_compile_time_arg_val(6);
    constexpr bool is_q_sharded = get_compile_time_arg_val(7);
    constexpr uint32_t num_cores_per_batch = get_compile_time_arg_val(8);
    constexpr uint32_t k_chunk_size = get_compile_time_arg_val(9);
    constexpr uint32_t index_stick_size_B = get_compile_time_arg_val(10);
    constexpr bool is_paged_attention = get_compile_time_arg_val(11) == 1;
    constexpr uint32_t num_kv_heads = get_compile_time_arg_val(12);
    constexpr uint32_t block_size_t = get_compile_time_arg_val(13);
    constexpr uint32_t Bkv = get_compile_time_arg_val(14);
    constexpr uint32_t q_heads_parallel_factor = get_compile_time_arg_val(15);
    constexpr uint32_t num_cores_per_head = get_compile_time_arg_val(16);
    constexpr uint32_t num_heads_per_core = get_compile_time_arg_val(17);
    constexpr uint32_t num_output_cores = get_compile_time_arg_val(18);
    constexpr bool is_causal = get_compile_time_arg_val(19) == 1;
    constexpr bool use_attention_mask = get_compile_time_arg_val(20) == 1;
    constexpr bool use_attention_sink = get_compile_time_arg_val(21) == 1;
    constexpr uint32_t max_dynamic_chunk_size = get_compile_time_arg_val(22);
    constexpr bool tilize_q = get_compile_time_arg_val(23) == 1;
    constexpr bool reuse_k = get_compile_time_arg_val(24) == 1;
    // NOTE(review): normalized `== 1` for consistency with the other boolean
    // compile-time flags above (host passes these as 0/1).
    constexpr bool use_half_tile = get_compile_time_arg_val(25) == 1;
    constexpr uint32_t q_chunk_size_bytes = get_compile_time_arg_val(26);
    constexpr bool is_cur_pos_tensor_sharded = get_compile_time_arg_val(27) == 1;
    constexpr bool is_page_table_sharded = get_compile_time_arg_val(28) == 1;
    constexpr uint32_t q_page_size_bytes = get_compile_time_arg_val(29);

    // Tensor accessor compile-time args are packed back-to-back starting at index 30.
    constexpr auto k_args = TensorAccessorArgs<30>();
    constexpr auto q_args = TensorAccessorArgs<k_args.next_compile_time_args_offset()>();
    constexpr auto v_args = TensorAccessorArgs<q_args.next_compile_time_args_offset()>();
    constexpr auto mask_args = TensorAccessorArgs<v_args.next_compile_time_args_offset()>();
    constexpr auto pos_args = TensorAccessorArgs<mask_args.next_compile_time_args_offset()>();
    constexpr auto page_table_args = TensorAccessorArgs<pos_args.next_compile_time_args_offset()>();
    constexpr auto attention_sink_args = TensorAccessorArgs<page_table_args.next_compile_time_args_offset()>();

    uint32_t arg_idx = 0;
    const uint32_t q_addr = get_arg_val<uint32_t>(arg_idx++);
    const uint32_t k_addr = get_arg_val<uint32_t>(arg_idx++);
    const uint32_t v_addr = get_arg_val<uint32_t>(arg_idx++);
    const uint32_t pos_addr = get_arg_val<uint32_t>(arg_idx++);
    const uint32_t page_table_addr = get_arg_val<uint32_t>(arg_idx++);
    const uint32_t mask_addr = get_arg_val<uint32_t>(arg_idx++);
    const uint32_t attention_sink_addr = get_arg_val<uint32_t>(arg_idx++);
    const uint32_t page_table_page_size = get_arg_val<uint32_t>(arg_idx++);
    // is_worker is currently unused here, but the argument must still be
    // consumed to keep arg_idx aligned with the host-side runtime args.
    const bool is_worker = get_arg_val<uint32_t>(arg_idx++) == 0;
    const bool is_output_core = get_arg_val<uint32_t>(arg_idx++) == 1;
    const uint32_t cur_head_group = get_arg_val<uint32_t>(arg_idx++);
    const uint32_t cur_batch = get_arg_val<uint32_t>(arg_idx++);
    const uint32_t core_num_in_reduce = get_arg_val<uint32_t>(arg_idx++);
    const uint32_t core_num_in_output = get_arg_val<uint32_t>(arg_idx++);
    const uint32_t cur_pos_arg = get_arg_val<uint32_t>(arg_idx++);

    // idle core
    if (q_addr == 0) {
        return;
    }
    // Get cur_pos
    constexpr uint32_t cur_pos_base = St * 32 - 1;
    uint32_t cur_pos = cur_pos_base;  // default to non-causal, which we do attention on the entire kv cache. In this
                                      // case we set cur_pos to the last position
    if constexpr (is_causal) {
        // using UINT32_MAX as a flag to indicate that cur_pos is not provided as a list
        if (cur_pos_arg != UINT32_MAX) {
            cur_pos = cur_pos_arg;
        } else {
            // cur_pos comes from the position tensor instead of a scalar arg.
            constexpr uint32_t cb_index_id = tt::CBIndex::c_8;
            cb_reserve_back(cb_index_id, 1);
            uint32_t index_cb_wr_ptr = get_write_ptr(cb_index_id);

            if constexpr (!is_cur_pos_tensor_sharded) {
                const auto addrg = TensorAccessor(pos_args, pos_addr, index_stick_size_B);

                // index_tensor has one page to read
                uint64_t tensor_index_noc_addr = addrg.get_noc_addr(0);
                noc_async_read(tensor_index_noc_addr, index_cb_wr_ptr, index_stick_size_B);
                noc_async_read_barrier();
            }
            // When sharded, the position tensor is already resident in the CB's L1
            // space, so no NOC read is needed before dereferencing the pointer.

            cb_push_back(cb_index_id, 1);
            volatile tt_l1_ptr uint32_t* index_ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(index_cb_wr_ptr);
            cur_pos = index_ptr[cur_batch / q_heads_parallel_factor];
        }

        if (cur_pos == UINT32_MAX) {
            // cur_pos of -1 indicates that the user should be skipped
            return;
        }
    }

    auto Sk_chunk_t_dynamic = get_dynamic_Sk_chunk_t<Sk_chunk_t, max_dynamic_chunk_size>(cur_pos);
    auto k_chunk_size_dynamic = Sk_chunk_t_dynamic * tt::constants::TILE_HEIGHT;

    // Sequence length assignment
    auto [PSt, k_num_chunks, k_chunk_start, k_chunk_end] =
        get_runtime_args(cur_pos, cur_batch, core_num_in_reduce, num_cores_per_head, k_chunk_size_dynamic);

    if (k_chunk_start == k_chunk_end) {
        return;  // early exit because no computes needs to be done
    }

    // Remaining runtime args: per-output-core NOC coordinates.
    tt_l1_ptr uint32_t* all_output_noc_x = (tt_l1_ptr uint32_t*)(get_arg_addr(arg_idx));
    arg_idx += num_output_cores;
    tt_l1_ptr uint32_t* all_output_noc_y = (tt_l1_ptr uint32_t*)(get_arg_addr(arg_idx++));

    uint32_t output_core_noc_x = all_output_noc_x[cur_batch];
    uint32_t output_core_noc_y = all_output_noc_y[cur_batch];

    // Tile counts per chunk; K/V/mask counts depend on the dynamic chunk size.
    constexpr uint32_t q_chunk_tiles = PNHt * DHt;
    uint32_t k_chunk_tiles = Sk_chunk_t_dynamic * DHt;
    uint32_t v_chunk_tiles = Sk_chunk_t_dynamic * vDHt;
    uint32_t mask_chunk_tiles = PNHt * Sk_chunk_t_dynamic;

    constexpr uint32_t cb_q_in = tt::CBIndex::c_0;
    constexpr uint32_t cb_q_rm = tt::CBIndex::c_10;  // row-major Q staging when tilize_q is set
    constexpr uint32_t cb_k_in = tt::CBIndex::c_1;
    constexpr uint32_t cb_v_in = tt::CBIndex::c_2;
    constexpr uint32_t cb_mask_in = tt::CBIndex::c_3;
    constexpr uint32_t cb_attention_sink = tt::CBIndex::c_4;

    constexpr uint32_t onetile = 1;
    constexpr uint32_t q_tile_bytes = get_tile_size(cb_q_in);
    constexpr uint32_t k_tile_bytes = get_tile_size(cb_k_in);
    constexpr uint32_t v_tile_bytes = get_tile_size(cb_v_in);
    constexpr uint32_t mask_tile_bytes = get_tile_size(cb_mask_in);
    constexpr uint32_t attention_sink_tile_bytes = get_tile_size(cb_attention_sink);

    // Issue a read barrier every `barrier_threshold` outstanding reads to bound
    // the number of in-flight NOC transactions.
    constexpr uint32_t barrier_threshold = get_barrier_read_threshold<q_tile_bytes, num_cores>();
    uint32_t barrier_count = 0;

    // First, read Q entirely, it could be interleaved or sharded
    uint32_t q_batch_offset = cur_batch * q_chunk_tiles;

    if constexpr (is_q_sharded) {
        uint64_t q_read_addr;
        uint32_t q_write_ptr;
        if (is_output_core) {
            // Q shard lives in this core's own L1.
            q_read_addr = get_noc_addr(q_addr);
        } else {
            // Q shard lives on the designated output core for this batch.
            q_read_addr = get_noc_addr(output_core_noc_x, output_core_noc_y, q_addr);
        }
        if constexpr (tilize_q) {
            cb_reserve_back(cb_q_rm, q_chunk_tiles);
            q_write_ptr = get_write_ptr(cb_q_rm);
        } else {
            cb_reserve_back(cb_q_in, q_chunk_tiles);
            q_write_ptr = get_write_ptr(cb_q_in);
        }
        if constexpr (use_half_tile and not tilize_q) {
            // q_addr represents 32x32 tiles; read them as 16x32 tiles
            // TODO: Properly setup q input as tiny tiles and remove special handling for tiny tiles
            // BUGFIX(review): counter was uint8_t, which wraps (infinite loop)
            // whenever q_chunk_tiles (= PNHt * DHt, a uint32_t) exceeds 255.
            for (uint32_t tile = 0; tile < q_chunk_tiles; tile++) {
                noc_async_read(q_read_addr, q_write_ptr, q_tile_bytes);
                q_read_addr += 2 * q_tile_bytes;  // skip the bottom half of each 32x32 tile
                q_write_ptr += q_tile_bytes;
            }
        } else {
            noc_async_read(q_read_addr, q_write_ptr, q_chunk_size_bytes);
        }
        noc_async_read_barrier();
        if constexpr (tilize_q) {
            cb_push_back(cb_q_rm, q_chunk_tiles);
        } else {
            cb_push_back(cb_q_in, q_chunk_tiles);
        }
    } else {
        // Interleaved Q: read tile by tile via the accessor.
        const auto q_reader = TensorAccessor(q_args, q_addr, q_page_size_bytes);
        uint32_t q_tile_id = q_batch_offset;
        cb_reserve_back(cb_q_in, q_chunk_tiles);
        uint32_t q_write_ptr = get_write_ptr(cb_q_in);
        for (uint32_t tile = 0; tile < q_chunk_tiles; ++tile) {
            uint64_t q_read_addr = q_reader.get_noc_addr(q_tile_id);
            noc_async_read(q_read_addr, q_write_ptr, q_tile_bytes);
            q_tile_id += 1;
            q_write_ptr += q_tile_bytes;
            if (++barrier_count == barrier_threshold) {
                noc_async_read_barrier();
                barrier_count = 0;
            }
        }
        noc_async_read_barrier();
        cb_push_back(cb_q_in, q_chunk_tiles);
    }

    // Read the rest
    const auto k_reader = TensorAccessor(k_args, k_addr, k_tile_bytes);

    const auto v_reader = TensorAccessor(v_args, v_addr, v_tile_bytes);

    const auto mask_reader = TensorAccessor(mask_args, mask_addr, mask_tile_bytes);

    // Read attention sink
    if constexpr (use_attention_sink) {
        const auto attention_sink_reader =
            TensorAccessor(attention_sink_args, attention_sink_addr, attention_sink_tile_bytes);

        cb_reserve_back(cb_attention_sink, PNHt);
        uint32_t attention_sink_write_ptr = get_write_ptr(cb_attention_sink);

        for (uint32_t tile = 0; tile < PNHt; ++tile) {
            noc_async_read_tile(tile, attention_sink_reader, attention_sink_write_ptr);
            attention_sink_write_ptr += attention_sink_tile_bytes;
        }
        noc_async_read_barrier();
        cb_push_back(cb_attention_sink, PNHt);
    }

    volatile tt_l1_ptr uint32_t* page_table_ptr;
    uint32_t page_table_cb_wr_ptr = 0;
    // Typed pointers for page table entries in L1
    volatile tt_l1_ptr uint16_t* page_table_ptr_u16 = nullptr;  // sharded page table: u16 entries
    volatile tt_l1_ptr uint32_t* page_table_ptr_u32 = nullptr;  // interleaved page table: u32 entries
    if constexpr (is_paged_attention) {
        constexpr uint32_t cb_id_page_table = tt::CBIndex::c_9;
        constexpr uint32_t num_pages_to_read = is_page_table_sharded ? B : 1;
        cb_reserve_back(cb_id_page_table, num_pages_to_read);

        // Read page table from DRAM
        if constexpr (!is_page_table_sharded) {
            page_table_cb_wr_ptr = get_write_ptr(cb_id_page_table);
            const auto page_table_gen = TensorAccessor(page_table_args, page_table_addr, page_table_page_size);
            uint64_t page_table_noc_addr = page_table_gen.get_noc_addr((cur_batch / q_heads_parallel_factor));
            noc_async_read(page_table_noc_addr, page_table_cb_wr_ptr, page_table_page_size);
            noc_async_read_barrier();
            page_table_ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(page_table_cb_wr_ptr);
            page_table_ptr_u32 = page_table_ptr;

        } else {  // Read page table from dynamically allocated L1 buffer
            page_table_cb_wr_ptr =
                get_write_ptr(cb_id_page_table) + (cur_batch / q_heads_parallel_factor) * page_table_page_size;
            page_table_ptr_u16 = reinterpret_cast<volatile tt_l1_ptr uint16_t*>(page_table_cb_wr_ptr);
        }

        cb_push_back(cb_id_page_table, num_pages_to_read);
    }

    // Stream K/V (and mask) chunks for every head this core is responsible for.
    for (uint32_t cur_head = cur_head_group * num_heads_per_core;
         cur_head < cur_head_group * num_heads_per_core + num_heads_per_core;
         ++cur_head) {
        const uint32_t mask_batch_offset = ((cur_batch / q_heads_parallel_factor) % Bkv) * PNHt * St;
        const uint32_t mask_chunk_offset = k_chunk_start * Sk_chunk_t_dynamic;
        uint32_t mask_start_tile_id = mask_batch_offset + mask_chunk_offset;
        if constexpr (is_paged_attention) {
            for (uint32_t k_chunk = k_chunk_start; k_chunk < k_chunk_end; ++k_chunk) {
                const uint32_t k_chunk_start_row_num = k_chunk * Sk_chunk_t_dynamic;
                uint64_t k_base_read_ptr;
                {
                    // Read K chunk in row-major order (to simplify page mapping). Write tiles to CB in transposed
                    // order.
                    cb_reserve_back(cb_k_in, k_chunk_tiles);
                    uint32_t k_write_ptr = get_write_ptr(cb_k_in);
                    k_base_read_ptr = get_noc_addr(k_write_ptr);
                    barrier_count = 0;
                    for (uint32_t row = 0; row < Sk_chunk_t_dynamic; ++row) {
                        uint32_t k_write_ptr_col = k_write_ptr + row * k_tile_bytes;
                        uint32_t virtual_k_tile_row_num = k_chunk_start_row_num + row;

                        // Translate the virtual sequence tile through the page
                        // table (u16 entries when sharded, u32 when interleaved).
                        uint32_t physical_k_tile_id =
                            (is_page_table_sharded)
                                ? virtual_seq_tile_id_to_physical_tile_id<uint16_t, num_kv_heads, block_size_t, DHt>(
                                      virtual_k_tile_row_num, cur_head, page_table_ptr_u16)
                                : virtual_seq_tile_id_to_physical_tile_id<num_kv_heads, block_size_t, DHt>(
                                      virtual_k_tile_row_num, cur_head, page_table_ptr_u32);
                        for (uint32_t col = 0; col < DHt; ++col) {
                            noc_async_read_tile(physical_k_tile_id, k_reader, k_write_ptr_col);
                            physical_k_tile_id += 1;                               // Go to next tile in row
                            k_write_ptr_col += Sk_chunk_t_dynamic * k_tile_bytes;  // Go to next column in CB

                            if (++barrier_count == barrier_threshold) {
                                noc_async_read_barrier();
                                barrier_count = 0;
                            }
                        }
                    }
                    noc_async_read_barrier();
                    cb_push_back(cb_k_in, k_chunk_tiles);
                }

                if constexpr (use_attention_mask) {
                    mask_start_tile_id = read_mask_chunk<cb_mask_in, mask_tile_bytes, barrier_threshold, PNHt>(
                        PSt, Sk_chunk_t_dynamic, mask_chunk_tiles, mask_start_tile_id, mask_reader);
                }

                {
                    if constexpr (reuse_k) {
                        // Read V chunk (transpose of K), from K's L1 buffer
                        cb_reserve_back(cb_v_in, v_chunk_tiles);
                        uint32_t v_write_ptr = get_write_ptr(cb_v_in);
                        uint64_t k_read_ptr = k_base_read_ptr;

                        for (uint32_t row = 0; row < Sk_chunk_t_dynamic; ++row) {  // Row of V
                            k_read_ptr = k_base_read_ptr + row * k_tile_bytes;     // Increment across K's Col

                            for (uint32_t col = 0; col < vDHt; ++col) {  // Col of V
                                noc_async_read(k_read_ptr, v_write_ptr, v_tile_bytes);

                                v_write_ptr += v_tile_bytes;
                                k_read_ptr += Sk_chunk_t_dynamic * k_tile_bytes;  // Stride across K's width
                            }
                        }
                    } else {
                        // Read V chunk in row major order, write in row-major order
                        cb_reserve_back(cb_v_in, v_chunk_tiles);
                        uint32_t v_write_ptr = get_write_ptr(cb_v_in);
                        barrier_count = 0;

                        for (uint32_t row = 0; row < Sk_chunk_t_dynamic; ++row) {
                            uint32_t virtual_v_tile_row_num = k_chunk_start_row_num + row;
                            uint32_t physical_v_tile_id =
                                (is_page_table_sharded)
                                    ? virtual_seq_tile_id_to_physical_tile_id<
                                          uint16_t,
                                          num_kv_heads,
                                          block_size_t,
                                          DHt>(virtual_v_tile_row_num, cur_head, page_table_ptr_u16)
                                    : virtual_seq_tile_id_to_physical_tile_id<num_kv_heads, block_size_t, DHt>(
                                          virtual_v_tile_row_num, cur_head, page_table_ptr_u32);
                            for (uint32_t col = 0; col < vDHt; ++col) {
                                noc_async_read_tile(physical_v_tile_id, v_reader, v_write_ptr);
                                physical_v_tile_id += 1;
                                v_write_ptr += v_tile_bytes;

                                if (++barrier_count == barrier_threshold) {
                                    noc_async_read_barrier();
                                    barrier_count = 0;
                                }
                            }
                            physical_v_tile_id += (DHt - vDHt);  // Skip the padding!
                        }
                    }

                    noc_async_read_barrier();
                    cb_push_back(cb_v_in, v_chunk_tiles);
                }

            }
        } else {
            // Offset for current batch
            const uint32_t k_batch_offset = ((cur_batch / q_heads_parallel_factor) % Bkv) * num_kv_heads * St * DHt;
            const uint32_t k_head_offset = cur_head * St * DHt;

            // Then, read K, V, Mask k_chunk_tiles at a time
            const uint32_t k_chunk_offset = k_chunk_start * Sk_chunk_t_dynamic * DHt;
            uint32_t k_start_tile_id = k_batch_offset + k_head_offset + k_chunk_offset;

            read_kv_mask_chunks<
                DHt,
                vDHt,
                barrier_threshold,
                mask_tile_bytes,
                PNHt,
                use_attention_mask,
                cb_k_in,
                cb_v_in,
                cb_mask_in,
                reuse_k>(
                k_chunk_start,
                k_chunk_end,
                k_start_tile_id,
                mask_start_tile_id,
                Sk_chunk_t_dynamic,
                k_chunk_tiles,
                v_chunk_tiles,
                mask_chunk_tiles,
                k_reader,
                v_reader,
                mask_reader,
                k_tile_bytes,
                v_tile_bytes,
                PSt);
        }
    }
}
